#%%
import torch
x = torch.randn((3,3), dtype=torch.float32, device="cuda:0")
y = x.detach().numpy()
# %%
